import safety_gym
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import time
import numpy as np
import torch
import torch.nn as nn
import gym
import sys
import os
os.chdir('/home/user/safety-starter-agents/safe_rl/PPO_Lagrangian_PyTorch/data/PPO-POINT-VAL/')
sys.path.append('/home/user/safety-starter-agents/safe_rl/PPO_Lagrangian_PyTorch/data/PPO-POINT-TRAIN/')
import core
# from utils.logx import EpochLogger
from utils.mpi_tools import mpi_fork, proc_id, num_procs, mpi_sum
torch.autograd.set_detect_anomaly(True)
import sysv_ipc
import torch.nn.functional as F
import copy
import multiprocessing
import pandas as pd
import gc

class Safety_NN(nn.Module):
    """Fully-connected safety classifier: input -> 128 -> 128 -> n_class.

    Args:
        n_state: Width of the input vector (observation + action here).
        n_class: Number of output logits.

    The sub-modules are named ``layer1``/``layer2``/``layer3`` because the
    checkpoints are restored with ``load_state_dict`` using those keys.
    """

    def __init__(self, n_state, n_class):
        super().__init__()
        width = 128
        self.layer1 = nn.Linear(n_state, width)
        self.layer2 = nn.Linear(width, width)
        self.layer3 = nn.Linear(width, n_class)

    def forward(self, x):
        """Return raw (un-normalised) class scores for ``x``."""
        hidden = F.relu(self.layer2(F.relu(self.layer1(x))))
        return self.layer3(hidden)


# Shared configuration for every validation worker.
n_observations = 60        # observation width fed to Safety_NN (before the action columns)
n_action = 2               # action width appended after the observation
n_class = 2                # Safety_NN output width
n_rank = 2                 # label count used by the posterior loops (safe / unsafe)
n_NN = 1                   # number of Safety_NN networks
n_sample_of_action = 252   # candidate actions scored per environment step



def validation_for_one_process(exp_name, thread_order, key_offset, agent, n_epoch, scale, default_action_margin, context,
                               local_steps_per_epoch, render, hazard_check_margin, max_ep_len, val_data_path, ckpt_info_sub):
    """Run POF-filtered validation rollouts in a single worker process.

    The worker:
      1. picks a CUDA device from ``thread_order``;
      2. loads the agent weights (``"ac_ppo"``), the stored test set
         ("intest" states and labels) and the ``n_NN`` Safety_NN weights
         (``"SNN_<i>"``) from
         ``<exp_name>/checkpoint_scale/<scale>/pof-checkpoint-1_epoch-<ckpt_info_sub>.pt``;
      3. scores the intest states once with every Safety_NN;
      4. steps ``Safexp-PointGoal1-v0`` for ``n_epoch * local_steps_per_epoch``
         steps. At every step it scores the candidate actions returned by
         ``agent.stepv2``, writes them with calibration data to SysV shared
         memory, and blocks until the peer process replies with the action
         to take. The peer is not in this file; it is presumably the C++
         side mentioned in the "copied from cpp" comment;
      5. appends accumulated cost, reward, a CV ratio and every chosen action
         to ``POF_COST<k>.txt`` / ``POF_PF<k>.txt`` / ``POF_CV<k>.txt`` /
         ``POF_ACT<k>.txt`` under ``val_data_path`` (``k`` = ``thread_order``).

    Args:
        exp_name: Experiment directory that contains ``checkpoint_scale/``.
        thread_order: Worker index. It selects the CUDA device (0-3 -> cuda:1,
            4-8 -> cuda:2, otherwise cuda:3), the IPC keys and the output
            file suffix.
        key_offset: Added to ``thread_order`` to derive the IPC keys.
        agent: Actor-critic module. Its weights are replaced from the
            checkpoint, and it must provide ``stepv2``.
        n_epoch: Number of outer epochs.
        scale: Checkpoint sub-directory. It also sets ``xi = 10.0 / scale``.
        default_action_margin: ``|x_vel|`` threshold that selects the
            default-action code sent to the peer.
        context: Values appended verbatim to every shared-memory message.
        local_steps_per_epoch: Environment steps per epoch.
        render: Render the environment each step (only when ``proc_id() == 0``).
        hazard_check_margin: A step is flagged as near a hazard when any entry
            of ``state[22:38]`` is >= this value.
        max_ep_len: Episode length at which statistics are written and the
            environment is reset.
        val_data_path: Output path prefix. File names are appended by plain
            string concatenation, so it should end with a path separator.
        ckpt_info_sub: Checkpoint epoch tag used in the checkpoint file name.

    Returns:
        None. Results are only written to files.
    """
    # Static GPU assignment by worker index.
    if thread_order<4:
        device = torch.device("cuda:1")
    elif thread_order<9:
        device = torch.device("cuda:2")
    else:
        device = torch.device("cuda:3")
    print(device)
    # One safety classifier per network slot; input = observation + action.
    SNN_list = [Safety_NN(n_observations+n_action, n_class) for _ in range(n_NN)]
    for i in range(n_NN): SNN_list[i].eval()
    ckpt_path = exp_name + "/checkpoint_scale/" + str(scale) +"/pof-checkpoint-1_epoch-" + ckpt_info_sub + ".pt"
    try: # (1)agent, (2)intest, (3)classifier
        backup = torch.load(ckpt_path)
        storage_intest_number = backup["storage_intest_number"] ##### (not used below)
        storage_intest_state_list = backup["storage_intest_state_list"]
        storage_intest_label_list = backup["storage_intest_label_list"]
        print(storage_intest_state_list.shape)
        agent.load_state_dict(backup["ac_ppo"])
        agent.eval() 
        for i in range(n_NN):
            temp_name = "SNN_" + str(i)
            SNN_list[i].load_state_dict(backup[temp_name])
            SNN_list[i].eval()
        print("    LOAD POF %s" %(ckpt_path))
    except Exception as e:
        # NOTE(review): any failure here (missing file, missing key, weight
        # shape mismatch) ends the worker via exit(), which reports status 0.
        print(e)
        exit()
    # Score the stored intest set once per network. Each resulting row is the
    # label part followed by the Safety_NN scores / 2; column 0 is read as the
    # label and columns 1..n_class as scores in
    # pof_calc_normaltable_and_posterior.
    result_intest_list = list()
    for nn_order in range(n_NN):
        # Clip columns 60:62 to [-1, 1], then binarise column 60 to +1 (> 0)
        # or -1 (<= 0). POF_batch_making applies the same preprocessing; there
        # columns 60:62 are the action part (60 obs + 2 action). The intest
        # arrays are presumably laid out the same way; confirm.
        storage_intest_state_list[nn_order][:,60:62]=np.clip(storage_intest_state_list[nn_order][:,60:62],-1,1)
        for i in range(storage_intest_state_list[nn_order].shape[0]):
            if storage_intest_state_list[nn_order][i][60]>0: storage_intest_state_list[nn_order][i][60]=1.0
            else: storage_intest_state_list[nn_order][i][60]=-1.0
        # Module.to(device) moves the network in place, so the later calls in
        # POF_batch_making also run on `device`.
        # NOTE(review): the whole set is scored in one batch outside
        # torch.no_grad(), so autograd buffers are held until .detach(). This
        # may be memory-heavy for large intest sets.
        result_intest_tmp = SNN_list[nn_order].to(device)(torch.FloatTensor(storage_intest_state_list[nn_order]).to(device)).detach().cpu().numpy()/2.
        result_intest_list.append(np.concatenate((storage_intest_label_list[nn_order], result_intest_tmp), axis=1))
        gc.collect()
        torch.cuda.empty_cache()
    print(str(thread_order)+"start")
    # Score margin used when building the normal table.
    xi=10.0/scale
    env_name = 'Safexp-PointGoal1-v0'
    env = gym.make(env_name)
    # global variable section copied from cpp.
    normt_init = True # only for 1 condition checking
    # normaltable_global[net][label][class][0 or 1]: hit counts filled by
    # pof_calc_normaltable_and_posterior. Only the first n_class entries of
    # the third axis are used.
    normaltable_global = np.zeros((n_NN,2,3,2), dtype=int)
    # posterior_global[net][class][label]: filled by pof_calc_posterior.
    posterior_global = np.zeros((n_NN,3,2), dtype=float)
    # ans_sum[net][label]: number of intest rows seen per label.
    ans_sum = np.zeros((n_NN,2), dtype=int)
    # IPC setting
    # Segment 1 (50000 bytes) carries the request to the peer. Segment 2
    # (100 bytes) carries its reply. The segments are created if missing
    # (IPC_CREAT), but the semaphores are opened with flags=0. In sysv_ipc
    # that means "open an existing semaphore, raise if absent", so the peer
    # must have created them before this worker starts.
    key_1, key_2 = (key_offset + thread_order)*1157, (key_offset + thread_order)*1257
    pof_shared_memory_1 = sysv_ipc.SharedMemory(key=key_1, flags=sysv_ipc.IPC_CREAT, size=50000) #@@ 900000
    sem_1 = sysv_ipc.Semaphore(key_1, 0)
    pof_shared_memory_2 = sysv_ipc.SharedMemory(key=key_2, flags=sysv_ipc.IPC_CREAT, size=100) #@@ 450000
    sem_2 = sysv_ipc.Semaphore(key_2, 0)
    
    def POF_batch_making(state, a_candidate_dense):
        """Score every candidate action for ``state`` with each Safety_NN.

        Builds a batch with ``n_sample_of_action`` rows. Each row is the state
        followed by one candidate (``a_candidate_dense`` must be reshapeable
        to ``(n_sample_of_action, -1)``). The batch gets the same clip and
        binarise preprocessing as the intest set and is run through each
        network.

        Returns:
            np.ndarray of shape ``(n_NN * n_sample_of_action, n_class)`` with
            the raw network outputs, stacked network by network.
        """
        global n_sample_of_action
        nonlocal storage_intest_label_list  # not used in this function
        result_batch_list = list()
        result_sample_array = np.empty([0,n_class])
        for nn_order in range(n_NN):
            state_action_batch_temp = np.concatenate((np.repeat(state.reshape(1,-1), repeats=n_sample_of_action, axis=0), a_candidate_dense.reshape(n_sample_of_action,-1)), axis=1)
            state_action_batch_temp[:,60:62]=np.clip(state_action_batch_temp[:,60:62],-1,1)
            for i in range(state_action_batch_temp.shape[0]):
                if state_action_batch_temp[i][60]>0: state_action_batch_temp[i][60]=1.0
                else: state_action_batch_temp[i][60]=-1.0
            # NOTE(review): unlike the intest scores, these outputs are not
            # divided by 2. The pass also runs outside torch.no_grad(), so an
            # autograd graph is built every step. Confirm both are intended.
            result_batch_list.append(SNN_list[nn_order](torch.FloatTensor(state_action_batch_temp).to(device)).to("cpu"))
            result_sample_tmp= result_batch_list[nn_order].detach().numpy()
            result_sample_array = np.append(result_sample_array, result_sample_tmp, axis=0)       
        return result_sample_array


    def POF_update_loading(shared_string):
        """Parse the peer's reply into ``(pof_action, default_action_using)``.

        Field 0 (first line) is converted with ``int``. Field 1 (second line)
        is returned as a string.

        NOTE(review): the caller passes the whole 100-byte segment, including
        any trailing bytes. The flag only compares equal to ``"0"`` in the
        rollout loop if the peer ends it with ``"\\n"``. Confirm the peer's
        format.
        """
        shared_string = shared_string.split("\n")
        temp = shared_string[0]
        pof_action = int(temp)
        temp = shared_string[1]
        default_action_using = str(temp)
        return pof_action, default_action_using
                                

    def POF_output_saving(writing_mode, x_vel, hazard_check, epoch, step, mu, std, result_sample_array, result_intest_list) -> str:
        """Serialise one request for the peer process.

        Both modes emit these fields in order:
          1. status (always 1, "val");
          2. number of intest rows;
          3. default-action code;
          4. mu;
          5. std;
          6. n_class;
          7. the candidate scores;
          8. the ``context`` values;
          9. ``min_output_class``, read from the enclosing scope.

        Default-action code (first match wins):
          * -10020 if hazard_check and x_vel < 0;
          * -10102 if hazard_check and x_vel > 0;
          * -102 if x_vel < -default_action_margin;
          * -20 if x_vel > default_action_margin;
          * -61 otherwise.

        Args:
            writing_mode: ``"txt"`` writes a debug file
                ``POF_OUTPUT_shm_<epoch>_<step>.txt`` under ``val_data_path``.
                mu and std are written with ``str()`` and scores one row per
                line. ``"shm"`` returns the message as a string: mu[0] and
                mu[1] clamped to [-4, 4], std[0] and std[1], and the scores
                flattened one value per line. The "shm" path assumes mu/std
                are torch tensors with at least 2 entries (it calls
                ``.detach().item()``).
            epoch, step: Used only in the ``"txt"`` file name.

        Returns:
            The newline-terminated message for ``"shm"``. Otherwise a status
            string (``"txt complete"`` / ``"nothing complete"``).
        """
        n_status = 1 # "val"
        if writing_mode == "txt":
            pof_output_path = val_data_path + "POF_OUTPUT_shm_"+str(epoch)+"_"+str(step)+".txt"
            file =  open(pof_output_path, "w")
            file.write(str(n_status))
            file.write("\n")
            file.write(str(result_intest_list[0].shape[0]))
            file.write("\n")
            if hazard_check and x_vel < 0: file.write(str(-10020))
            elif hazard_check and x_vel > 0: file.write(str(-10102))
            elif x_vel < -default_action_margin: file.write(str(-102))  
            elif x_vel > default_action_margin: file.write(str(-20)) 
            else: file.write(str(-61))  
            file.write("\n")
            file.write(str(mu))
            file.write("\n")
            file.write(str(std))
            file.write("\n")  
            file.write(str(n_class))
            file.write("\n")
            for i in range(result_sample_array.shape[0]):
                for j in range(result_sample_array.shape[1]):                
                    file.write(str(result_sample_array[i][j]))
                    file.write(" ")
                file.write("\n")
            for i in range(len(context)):
                file.write(str(context[i]))
                file.write(" ")
            file.write("\n")
            file.write(str(min_output_class))
            file.write("\n")
            file.close()
            return "txt complete"
        elif writing_mode=="shm":
            lines = [
                str(n_status),
                str(result_intest_list[0].shape[0]),
            ]
            if hazard_check and x_vel < 0: lines.append("-10020")
            elif hazard_check and x_vel > 0: lines.append("-10102")
            elif x_vel < -default_action_margin: lines.append("-102")
            elif x_vel > default_action_margin: lines.append("-20")
            else: lines.append("-61")
            lines.extend([
                str(max(min(mu[0].detach().item(), 4.0), -4.0)),
                str(max(min(mu[1].detach().item(), 4.0), -4.0)),
                str(std[0].detach().item()),
                str(std[1].detach().item()),
                str(n_class),
            ])
            lines.extend(map(str, result_sample_array.flatten()))
            lines.extend(map(str, context))
            lines.extend([str(min_output_class)])
            shared_string = "\n".join(lines) + "\n"
            return shared_string
        else: return "nothing complete"


    def pof_calc_posterior(normaltable_tmp, posterior_tmp, n_intest_safe, n_intest_unsafe):
        """Fill ``posterior_tmp`` in place from the normal table.

        For each network ``i`` and output class ``j``, let ``prior[k]`` be the
        share of label ``k`` in the intest counts. Per the argument names,
        k = 0 is safe and k = 1 is unsafe. Ignoring the 1e-11 epsilons::

            denominator = sum_k prior[k] * nt[i][k][j][0] / ans_sum[i][k]
            posterior[i][j][k] = prior[k] * nt[i][k][j][1] / ans_sum[i][k] / denominator

        Each value is then clamped to at most 1. ``ans_sum`` is read from the
        enclosing scope. Raises ZeroDivisionError if both counts are 0.
        """
        for i in range(n_NN):
            prior = np.zeros(2, dtype=float)
            prior[0] = n_intest_safe/(n_intest_safe + n_intest_unsafe)
            prior[1] = n_intest_unsafe/(n_intest_safe + n_intest_unsafe)
            for j in range(n_class):
                denominator = 0.00000000001
                for k in range(n_rank): denominator += prior[k] * normaltable_tmp[i][k][j][0] / (ans_sum[i][k] + 0.00000000001)
                for k in range(n_rank): posterior_tmp[i][j][k] = ((prior[k] * normaltable_tmp[i][k][j][1] + 0.00000000001) / (ans_sum[i][k] + 0.00000000001)+ 0.00000000001) / denominator
                for k in range(n_rank): 
                    if posterior_tmp[i][j][k]>1: posterior_tmp[i][j][k]=1


    def pof_calc_normaltable_and_posterior(normaltable_tmp, posterior_tmp, xi, result_intest_list, n_intest_safe, n_intest_unsafe):
        """Build the normal table from the intest scores, then the posterior.

        Uses the first ``n_intest_safe + n_intest_unsafe`` rows of each
        network's intest array. In each row, column 0 is the integer label
        and columns 1..n_class are the scores. For each row:
          * ``ans_sum[i][label] += 1``; ``ans_sum`` is the enclosing-scope
            array, mutated in place.
          * With ``max``/``secmax`` the largest and second-largest score in
            the row, for each class ``k``:
              - ``normaltable[i][label][k][1] += 1`` if ``score_k + xi > max``;
              - ``normaltable[i][label][k][0] += 1`` if ``score_k - xi > secmax``.

        Counts are added to the existing arrays, so a second call would
        double-count. ``pof_validation`` guards this with ``normt_init``.
        """
        # calculate normaltable
        for i in range(n_NN): 
            n_test_case_pm = n_intest_safe + n_intest_unsafe
            for j in range(n_test_case_pm):
                # NOTE: `max` shadows the builtin inside this function only.
                max = -5000000
                secmax = -5000000
                ans_tmp = int(result_intest_list[i][j][0])
                ans_sum[i][ans_tmp] += 1
                for k in range(1,n_class+1):
                    if result_intest_list[i][j][k] > max:
                        secmax = max
                        max = result_intest_list[i][j][k]
                    elif result_intest_list[i][j][k] > secmax: secmax = result_intest_list[i][j][k]
                for k in range(1,n_class+1):
                    if result_intest_list[i][j][k] + xi > max: normaltable_tmp[i][ans_tmp][k-1][1] += 1
                    if result_intest_list[i][j][k] - xi > secmax: normaltable_tmp[i][ans_tmp][k-1][0] += 1
        pof_calc_posterior(normaltable_tmp, posterior_tmp, n_intest_safe, n_intest_unsafe)
        
    


    def pof_validation(state_reshape, x_vel, hazard_check, mu, std, xi):
        """Run one request/reply round trip with the peer and return its decision.

        Steps:
          1. Score the candidates in ``a_candidate_dense``. This name is read
             from the enclosing scope: the value assigned by the current
             rollout step.
          2. On the first call only, build the normal table and posterior.
          3. Once, pick ``min_output_class``: the class with the smallest
             ``posterior_global[0][j][1]`` (network 0 only).
          4. Write the request to shared memory 1 and signal ``sem_1``.
          5. Block on ``sem_2`` (no timeout) and parse the reply from shared
             memory 2.

        Returns:
            ``(selected_action, default_action_using)`` as parsed by
            POF_update_loading.
        """
        nonlocal storage_intest_label_list, result_intest_list
        nonlocal pof_shared_memory_1, pof_shared_memory_2
        nonlocal normt_init
        nonlocal min_output_class
        result_sample_array = POF_batch_making(state_reshape, a_candidate_dense)
        if normt_init: # only once
            # The intest set is treated as half safe / half unsafe. With an
            # odd row count, int(.../2) drops the last row.
            pof_calc_normaltable_and_posterior(normaltable_global, posterior_global, xi, result_intest_list, int(result_intest_list[0].shape[0]/2),int(result_intest_list[0].shape[0]/2) )
            normt_init = False
        if  min_output_class == -1:
            min_posterior = 1e9
            for j in range(n_class):
                if posterior_global[0][j][1] < min_posterior:
                    min_posterior = posterior_global[0][j][1]
                    min_output_class = j
            # Posterior values are clamped to <= 1, so this is reached only
            # when no comparison succeeds (e.g. NaN posteriors).
            if (min_output_class == -1): 
                print("########################################### Err")
                exit()
        shared_string_1 = POF_output_saving("shm", x_vel, hazard_check, 0, 0, mu, std, result_sample_array, result_intest_list)
        # NOTE(review): write() overwrites from offset 0 and does not clear
        # bytes left over from a longer earlier message. The peer presumably
        # relies on the counts in the header. Confirm.
        pof_shared_memory_1.write(shared_string_1.encode())
        sem_1.V() ## S1-1 & S2-0
        sem_2.P() ## S1-0 & S2-0
        shared_string_2 = pof_shared_memory_2.read().decode()
        selected_action, default_action_using = POF_update_loading(shared_string_2)
        return selected_action, default_action_using

    # Result files are opened in append mode, so output from earlier runs is kept.
    pof_cost_path = val_data_path + "POF_COST"+str(thread_order)+".txt"
    pof_reward_path = val_data_path + "POF_PF"+str(thread_order)+".txt"
    pof_cv_path = val_data_path + "POF_CV"+str(thread_order)+".txt"
    pof_act_path = val_data_path + "POF_ACT"+str(thread_order)+".txt"
    costfile = open(pof_cost_path, "a")
    pffile = open(pof_reward_path, "a")
    cvfile = open(pof_cv_path, "a")
    acfile = open(pof_act_path, "a")
    
    state, reward, done, cost, ep_ret, ep_cost, ep_len = env.reset(), 0, False, 0, 0, 0, 0
    # Accumulated until an episode reaches max_ep_len, then written and reset.
    acc_cost = 0
    acc_reward = 0
    # Chosen lazily inside pof_validation; -1 means "not chosen yet".
    min_output_class = -1
    for epoch in range(n_epoch):
        # cv_denom[m] = number of steps with default_action_using == "0" among
        # the first m steps since the last cost reset (cv_denom[0] == 0).
        cv_denom = [0]
        cv_denom_counter = 0
        cv_counter = 0
        cv_index = 0
        for t in range(local_steps_per_epoch):
            if render and proc_id()==0: env.render()
            a_candidate, a_candidate_dense, v, vc, _, mu, std = agent.stepv2(torch.as_tensor(state, dtype=torch.float32))
            state_reshape = state.reshape((1,)+state.shape)
            # NOTE(review): index 57 is assumed to be the x-velocity entry of
            # the observation. Confirm against the env's observation layout.
            x_vel = state[57]
            hazard_check = not all(one < hazard_check_margin for one in state[22:38]) # self.observe_hazards
            selected_action, default_action_using = pof_validation(state_reshape, x_vel, hazard_check, mu, std, xi)
            if default_action_using == "0": cv_denom.append(cv_denom[cv_index] + 1)
            else: 
                cv_denom.append(cv_denom[cv_index])
            
            # Negative replies are default-action codes from POF_output_saving.
            # They map to fixed actions and are logged as -1000/-2000/-3000.
            # A non-negative reply indexes a_candidate. Any other negative
            # value (e.g. the hazard codes -10020/-10102) fails the assert,
            # so the peer presumably never returns those directly.
            if selected_action == -20:
                action = [-1.,0.]
                selected_action = -1000
            elif selected_action == -102:
                action = [1.,0.]
                selected_action = -2000
            elif selected_action == -61:
                action = [0.,0.]
                selected_action = -3000
            else:
                assert(selected_action >= 0)
                action = a_candidate[selected_action]
            acfile.write(str(selected_action) + "\n")
            next_state, reward, done, info = env.step(action)
            cost = info.get('cost', 0)
            acc_cost += cost
            acc_reward += reward
            state = next_state
            ep_ret += reward
            ep_cost += cost
            ep_len += 1
            if cost != 0: 
                # cv_counter += flagged steps among the last (up to) 60 steps,
                # including this one. cv_denom_counter += flagged steps since
                # the last reset.
                cv_counter += cv_denom[cv_index + 1] - cv_denom[max(0, cv_index - 59)]
                cv_denom_counter += cv_denom[cv_index + 1]
                cv_denom = [0]
                cv_index = 0
                # Reset on any cost. ep_ret/ep_cost/ep_len are left unchanged
                # (the `_` placeholders), so the max_ep_len counter keeps
                # running across cost resets.
                state, reward, done, cost, _, _, _ = env.reset(), 0, False, 0, 0, 0, 0
            else: cv_index += 1

            terminal = done or (ep_len == max_ep_len)
            if terminal:
                print("RESET at epoch:%d, local_epoch:%d" %(epoch, t+1))
                # Statistics are written only when the episode reaches
                # max_ep_len. On an earlier `done`, the accumulators carry
                # over into the next episode.
                if ep_len == max_ep_len:
                    costfile.write(str(acc_cost) + "\n")
                    pffile.write(str(acc_reward) + "\n")
                    # Add the flagged steps from the trailing violation-free stretch.
                    if cost == 0: cv_denom_counter+=cv_denom[cv_index]
                    cvfile.write(str(float(cv_counter/(cv_denom_counter+0.00000000001))) + "\n")
                    acc_cost = 0
                    acc_reward = 0
                    cv_denom = [0]
                    cv_denom_counter = 0
                    cv_counter = 0
                    cv_index = 0
                    # Close and reopen so the results are flushed to disk.
                    costfile.close()
                    pffile.close()
                    cvfile.close()
                    acfile.close()
                    costfile = open(pof_cost_path, "a")
                    pffile = open(pof_reward_path, "a")
                    cvfile = open(pof_cv_path, "a")
                    acfile = open(pof_act_path, "a")
                state, reward, done, cost, ep_ret, ep_cost, ep_len = env.reset(), 0, False, 0, 0, 0, 0
    costfile.close()
    pffile.close()
    cvfile.close()
    acfile.close()


if __name__ == '__main__':
    # Script entry point:
    #   1. parse options and validate the experiment name;
    #   2. build the agent;
    #   3. launch one validation worker process per safety-filter scale,
    #      keeping at most `num_cores` of them running at once.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Safexp-PointGoal1-v0')
    parser.add_argument('--hid', type=int, default=256)
    parser.add_argument('--l', type=int, default=2)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=10000) #####
    parser.add_argument('--local_steps_per_epoch', type=int, default=1000)
    parser.add_argument('--len', type=int, default=1000)
    parser.add_argument('--exp_name', type=str, default='error')
    parser.add_argument('--checkpoint', type=str, default='-1')
    parser.add_argument('--render', action='store_true')
    args = parser.parse_args()

    # Known validation experiments: SysV IPC key offset and cost limit for each.
    val_exp_set = {"pofval_poftraining_ppo_point_nolag": 1,
                   "pofval_poftraining_ppo_point_lag1.5": 2210} #####
    costlim = {"pofval_poftraining_ppo_point_nolag": 2.5,
                   "pofval_poftraining_ppo_point_lag1.5": 1.5} #####

    if args.exp_name in val_exp_set:
        key_offset = val_exp_set[args.exp_name]
    else:
        print(f"'{args.exp_name}'is invalid validation experiment. Please check 'val_exp_set'.")
        # Without an entry there is no key_offset and no cost limit. Abort
        # here with a failure status rather than crashing later with a
        # NameError / KeyError.
        sys.exit(1)

    from utils.run_utils import setup_logger_kwargs
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    # Seed setting
    seed = args.seed
    seed += 10000 * proc_id()
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # ETC setting
    env_fn = lambda: gym.make(args.env)
    ac_kwargs = dict(hidden_sizes=[args.hid]*args.l)
    render = args.render
    epochs = args.epochs
    local_steps_per_epoch = args.local_steps_per_epoch
    max_ep_len = args.len
    storage_intest_number_array = [5000000, 5000000]
    default_action_margin = 0.05
    hazard_check_margin = 0.935

    # Context values forwarded verbatim to the POF peer in every message.
    context = [60*costlim[args.exp_name],1000-60*costlim[args.exp_name]] #####

    # Main setting: exp_name is "<val>_<ckpt>_<agent>_<task>[_<agent_sub>[_<intest_scale>]]".
    exp_name_split = args.exp_name.split('_')
    val_info = exp_name_split[0]
    ckpt_info = exp_name_split[1]
    agent_info = exp_name_split[2]
    task_info = exp_name_split[3]
    # NOTE(review): with 6 name parts agent_info_sub stays None even though
    # part 5 exists. Confirm whether `>= 5` was intended.
    if len(exp_name_split) == 5: agent_info_sub = exp_name_split[4]
    else: agent_info_sub = None
    if len(exp_name_split) == 6:
        intest_scale = int(exp_name_split[5])
        storage_intest_number_array = [intest_scale, intest_scale] #####
    else: intest_scale = None
    ckpt_info_sub = args.checkpoint

    # Path setting: refuse to overwrite an existing validation directory.
    val_data_path = args.exp_name + "/validation_scale" + ckpt_info_sub + "/"
    if not os.path.exists(val_data_path):
        os.makedirs(val_data_path)
        print("    CREATE DIRECTORY %s" %(val_data_path))
    else:
        print("    DIRECTORY ALREADY EXISTS %s" %(val_data_path))
        exit()
    train_default_path = "../PPO-POINT-TRAIN/" + agent_info + "_" + task_info + "_"
    if agent_info_sub != None: train_default_path += (agent_info_sub + "/")

    # Agent setting. One probe environment supplies both spaces; the
    # original called env_fn() twice, building two environments.
    if agent_info == "ppo":
        actor_critic=core.MLPActorCritic_ppo_point_train
        probe_env = env_fn()
        agent = actor_critic(probe_env.observation_space, probe_env.action_space, **ac_kwargs)
        agent.eval()
    else: exit()

    # Validation setting. Workers reload the intest data from the checkpoint;
    # these placeholders are not passed to them.
    if val_info == "pofval":
        storage_intest_number = storage_intest_number_array[0] + storage_intest_number_array[1]
        storage_intest_state_list = []
        storage_intest_label_list =[]
        safe_intest_index_list = [list() for _ in range(n_NN)]
        unsafe_intest_index_list = [list() for _ in range(n_NN)]
    else: exit()

    # Run: one worker per scale, at most `num_cores` running at once. Workers
    # are started 3 s apart; when the limit is reached, the oldest is joined.
    if val_info == "pofval":
        torch.multiprocessing.set_start_method('spawn')
        start_time = time.time()
        scale_array = [1,2,5,10,20,50,100,200,500,1000]
        n_scale = len(scale_array)
        n_total_process = n_scale
        num_cores = 10
        procs = []
        for thread_order in range(n_total_process):
            p = multiprocessing.Process(target=validation_for_one_process, args=(args.exp_name, thread_order, key_offset, agent, epochs, scale_array[thread_order%n_scale],
                                                                                default_action_margin, context, local_steps_per_epoch, render, hazard_check_margin, max_ep_len, val_data_path, ckpt_info_sub))
            p.start()
            procs.append(p)
            time.sleep(3)
            print("thread_order: ", thread_order)
            while (len(procs) >= num_cores): procs.pop(0).join()
        for p in procs:
            p.join()
    else: exit()